#!/usr/bin/env python3
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Generates data and writes it to a file.
"""

import numpy as np
from polygraphy.common import TensorMetadata
from polygraphy.json import save_json
from polygraphy.tools import Tool
from polygraphy.tools.args import DataLoaderArgs


class GenData(Tool):
    # NOTE: Polygraphy builds the summary shown in the command-line help output
    # from the docstring of the Tool subclass, so the docstring below is user-facing.
    """
    Generate random data and write it to a file.
    """
    # Step 1: the constructor subscribes to the argument groups this tool is
    # interested in. Each subscribed group ends up in a member called
    # `arg_groups`, which maps argument group types to their instances.
    def __init__(self):
        super().__init__()
        loader_args = DataLoaderArgs()
        self.subscribe_args(loader_args)

    # Step 2: `add_parser_args` declares the tool-specific options that are not
    # already supplied by the subscribed argument groups.
    def add_parser_args(self, parser):
        parser.add_argument(
            "-o",
            "--output",
            required=True,
            help="Path at which to write generated data.",
        )
        parser.add_argument(
            "--num-values",
            type=int,
            default=1,
            help="The number of random values to generate.",
        )

    # Step 3: `run` holds the actual functionality of the tool.
    def run(self, args):
        # Describe the tensor we want generated: a single float32 input named "data"
        # whose shape is derived from --num-values. This becomes the `input_metadata`
        # handed to the data loader.
        input_metadata = TensorMetadata().add(name="data", dtype=np.float32, shape=(args.num_values,))

        # DataLoaderArgs exposes a `get_data_loader` helper that creates a new data loader
        # from the command-line options the user supplied.
        # See `polygraphy/tools/args/data_loader.py` for implementation details.
        loader_group = self.arg_groups[DataLoaderArgs]
        data_loader = loader_group.get_data_loader(input_metadata)

        # The data loader behaves like a generator/iterable, so converting it to a `list`
        # produces all of the data in one go before it is written out.
        generated = list(data_loader)
        save_json(generated, dest=args.output, description="randomly generated numbers")


# Finally, instantiate the tool and call its main() method so that this script
# behaves like a CLI tool.
gen_data_tool = GenData()
gen_data_tool.main()
